import json
from collections import defaultdict, Counter
import argparse
import os
import sys
from glob import glob
import copy
from tqdm import tqdm
from datasets import load_dataset
import random

# Lift the interpreter's cap on int <-> str conversion length (0 disables the limit).
# NOTE(review): `sys.set_int_max_str_digits` only exists on Python 3.11+ — confirm the target runtime.
sys.set_int_max_str_digits(0)
# Add the directory three levels above this file to the import path so that the
# `data.apps` import below can be resolved (presumably the repository root — confirm).
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from data.apps import APPsWithFunctionName


def _load_json_lines(path):
    """Return one parsed JSON object per non-blank line of the JSON-lines file at ``path``."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _load_completions(path_or_pattern):
    """Load completion records from one JSON file, or from every file matching a glob pattern.

    If ``path_or_pattern`` names an existing file, its parsed content is returned as is.
    Otherwise it is treated as a glob pattern and the contents of all matching files are
    concatenated. Matches are visited in sorted order so the result is reproducible.
    """
    if os.path.exists(path_or_pattern):
        with open(path_or_pattern, encoding="utf-8") as f:
            return json.load(f)

    data = []
    for path in sorted(glob(path_or_pattern)):
        with open(path, encoding="utf-8") as f:
            data += json.load(f)
    return data


def _collect_pairs(worsen_codes, p_id2response):
    """Group ``(original prediction, worsened program)`` pairs by problem id.

    Records with an empty completion, a completion that is not valid JSON, a completion that
    is not an object with an ``incorrect_program`` key, or a non-string / blank
    ``incorrect_program`` are skipped (the malformed ones are reported on stdout).

    :raises ValueError: if the prediction looked up for a record does not occur in that
        record's prompt, i.e. the worsen file and the completion file do not match.
    """
    p_id2pairs = defaultdict(list)
    for item in worsen_codes:
        if not item["completion"]:
            continue

        try:
            completion = json.loads(item["completion"])
        except (ValueError, TypeError):
            # json.JSONDecodeError is a ValueError; TypeError covers a non-string completion.
            print(f"Json parsing error: {item['completion']}")
            continue

        # The teacher model may return valid JSON that is not the expected object.
        if not isinstance(completion, dict) or "incorrect_program" not in completion:
            print(f"Bad format: {completion}.")
            continue

        neg_code = completion["incorrect_program"]
        if not isinstance(neg_code, str):
            print(f"Bad format: {neg_code}.")
            continue
        if not neg_code.strip():
            continue

        # The record id has the form "<problem id>_<3-character prefix><prediction index>".
        p_id, pred_id = item["id"].split("_")
        pred_id = int(pred_id[3:])
        p_id = int(p_id)

        pred = p_id2response[p_id]["pred"][pred_id]
        # Sanity check: the worsening prompt must have been built from this very prediction.
        if pred not in item["prompt"]:
            raise ValueError(
                f"Prediction {pred_id} of problem {p_id} does not occur in the prompt of "
                f"worsen record {item['id']!r}; the completion file does not match the worsen file."
            )

        p_id2pairs[p_id].append((pred, neg_code))
    return p_id2pairs


def main():
    """
    Combine the policy model's solutions with the teacher model's worsened versions of them.

    The script reads the teacher completions (``--worsen_file``, JSON lines) and the policy
    model completions (``--completion_file``, a JSON file or a glob pattern matching several),
    pairs every usable worsened program with the prediction it was derived from, and writes
    one record per problem to ``--output_file``::

        {"problem_id": <int>, "pos": [<prediction>, ...], "neg": [<worsened program>, ...]}

    ``pos[i]`` and ``neg[i]`` always belong together.

    :return: None. The result is written to ``--output_file``; counts are printed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--worsen_file", type=str,
                        help="The file contains completion from the teacher model for worsening code. "
                             "The inputs for this file are generated by `pp_worsen_inputs.py` script.")
    parser.add_argument("--completion_file", type=str,
                        help="The file contains the completion for each query.")
    parser.add_argument("--completion_problem_id_field", type=str, default="problem_id")
    parser.add_argument("--output_file", type=str)
    args = parser.parse_args()

    worsen_codes = _load_json_lines(args.worsen_file)
    print(f"Total number of worsen codes: {len(worsen_codes)}")

    data = _load_completions(args.completion_file)
    # On duplicate ids the record loaded last wins.
    p_id2response = {item[args.completion_problem_id_field]: item for item in data}

    p_id2pairs = _collect_pairs(worsen_codes, p_id2response)

    outputs = []
    num_pairs = 0
    for p_id, pairs in p_id2pairs.items():
        num_pairs += len(pairs)
        outputs.append({
            "problem_id": p_id,
            "pos": [pred for pred, _ in pairs],
            "neg": [neg_code for _, neg_code in pairs],
        })

    print(f"Total number of outputs: {len(outputs)}")
    print(f"Total number of pairs: {num_pairs}")
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding is pinned to UTF-8.
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(outputs, f, indent=2, ensure_ascii=False)


# Run the combination step only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()

"""
>>> python scripts/apps/worsen_gpt4_combine.py --worsen_file outputs/apps/critique/r2c.sft.train.0shot.tem1.0.n10.v1.1.worsen_4o_critic.s42.f2000.gpt4o.tem1.0.s42.n1.json_obj.jsonl \
    --completion_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.s43.json \
    --completion_problem_id_field id --output_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.s43.gpt4o.worsen.f2000.json

>>> python scripts/apps/worsen_gpt4_combine.py \
  --worsen_file outputs/apps/critique/r2c.sft.train.0shot.tem1.0.n10.v1.1.worsen_4o_critic.s42.f100k.gpt4o.tem1.0.s42.n1.json_obj.jsonl \
  --completion_file "../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.?-of-8.v1.1.json" \
  --completion_problem_id_field id \
  --output_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.s43.gpt4o.worsen.f100k.json
>>> NOTE: the command above is incorrect; use the corrected command under FIXME below.

FIXME:
>>> python scripts/apps/worsen_gpt4_combine.py \
  --worsen_file outputs/apps/critique/r2c.sft.train.0shot.tem1.0.n10.v1.1.worsen_4o_critic.s42.f100k.gpt4o.tem1.0.s42.n1.json_obj.jsonl \
  --completion_file "../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.s43.json" \
  --completion_problem_id_field id \
  --output_file ../msranlpintern/reward_modeling/experiments/deepseek-coder-v1.5-ins.7b.apps.r2c.gpt4o.distil.A100.w8.v3.0.s42/apps/checkpoint-400/train.0shot.tem1.0.n10.v1.1.s43.gpt4o.worsen.f100k.fix0708.json

"""
